{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1a1cab6c",
   "metadata": {},
   "source": [
    "# Comparison Between TreeValue and Jax LibTree"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9426f1e",
   "metadata": {},
   "source": [
    "In this section, we will take a look at the features and performance of the [jax-libtree](https://jax.readthedocs.io/en/latest/pytrees.html) library (JAX's pytree utilities, exposed in Python as `jax.tree_util`), which is developed by Google."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f4b6d16",
   "metadata": {},
   "outputs": [],
   "source": [
    "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1085059",
   "metadata": {},
   "source": [
    "## Mapping Operation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ded33f7",
   "metadata": {},
   "source": [
    "### TreeValue's Mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9abbed",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import mapping, FastTreeValue\n",
    "\n",
    "t = FastTreeValue(_TREE_DATA_1)\n",
    "mapping(t, lambda x: x ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc3420b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit mapping(t, lambda x: x ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "289e14c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping(t, lambda x, p: (x ** 2, p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fce10189",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit mapping(t, lambda x, p: (x ** 2, p))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18d4346d",
   "metadata": {},
   "source": [
    "### pytree's tree_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b8e706",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_map\n",
    "\n",
    "tree_map(lambda x: x ** 2, _TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41fa094b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit tree_map(lambda x: x ** 2, _TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36191201",
   "metadata": {},
   "source": [
    "## Flatten and Unflatten Operation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a733f250",
   "metadata": {},
   "source": [
    "### TreeValue's Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "676f03c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import flatten, flatten_keys, flatten_values\n",
    "\n",
    "t_flatted = flatten(t)\n",
    "t_flatted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ca0363d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit flatten(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d3bd2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import flatten_keys\n",
    "\n",
    "flatten_keys(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58fbf8c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit flatten_keys(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1448e341",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import flatten_values\n",
    "\n",
    "flatten_values(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "956cf1c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit flatten_values(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f35e3a6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import unflatten\n",
    "\n",
    "unflatten(t_flatted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52980ab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit unflatten(t_flatted)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ff80a6c",
   "metadata": {},
   "source": [
    "### pytree's Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e89fc565",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_flatten\n",
    "\n",
    "leaves, treedef = tree_flatten(_TREE_DATA_1)\n",
    "print('Leaves:', leaves)\n",
    "print('Treedef:', treedef)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d59c4d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit tree_flatten(_TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcb1318c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_unflatten\n",
    "\n",
    "tree_unflatten(treedef, leaves)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "638b5144",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit tree_unflatten(treedef, leaves)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffd1522b",
   "metadata": {},
   "source": [
    "## All Operation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "112c91a1",
   "metadata": {},
   "source": [
    "### TreeValue's Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b9c1c51",
   "metadata": {},
   "outputs": [],
   "source": [
    "all(flatten_values(t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb8b0b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit all(flatten_values(t))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86462abd",
   "metadata": {},
   "source": [
    "### pytree.tree_all's Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64d39484",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a8427db",
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_all(_TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4a4cd59",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit tree_all(_TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0221182",
   "metadata": {},
   "source": [
    "## Reduce Operation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "646b4a2e",
   "metadata": {},
   "source": [
    "### TreeValue's Reduce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "105e9001",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import reduce\n",
    "\n",
    "def _flatten_reduce(tree):\n",
    "    \"\"\"Fold the values from ``flatten_values(tree)`` together with ``+``.\n",
    "\n",
    "    No initial value is passed to ``reduce``, so the first value seeds the fold.\n",
    "    \"\"\"\n",
    "    return reduce(lambda acc, item: acc + item, flatten_values(tree))\n",
    "\n",
    "_flatten_reduce(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71c145c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit _flatten_reduce(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20b34964",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _flatten_reduce_with_init(tree):\n",
    "    \"\"\"Fold the values from ``flatten_values(tree)`` together with ``+``.\n",
    "\n",
    "    The fold is seeded with the initial value ``0``.\n",
    "    \"\"\"\n",
    "    return reduce(lambda acc, item: acc + item, flatten_values(tree), 0)\n",
    "\n",
    "_flatten_reduce_with_init(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca8562c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit _flatten_reduce_with_init(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12b2a28a",
   "metadata": {},
   "source": [
    "### pytree.tree_reduce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e9237e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_reduce\n",
    "\n",
    "tree_reduce(lambda x, y: x + y, _TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c5cfda",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e79fb73b",
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af47dd74",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e58b70f",
   "metadata": {},
   "source": [
    "## Structure Transpose"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4a31f9e",
   "metadata": {},
   "source": [
    "### Subside and Rise in TreeValue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6af689f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import subside\n",
    "\n",
    "value = {\n",
    "    'a': FastTreeValue({'a': 1, 'b': {'x': 2, 'y': 3}}),\n",
    "    'b': FastTreeValue({'a': 10, 'b': {'x': 20, 'y': 30}}),\n",
    "    'c': {\n",
    "        'x': FastTreeValue({'a': 100, 'b': {'x': 200, 'y': 300}}),\n",
    "        'y': FastTreeValue({'a': 400, 'b': {'x': 500, 'y': 600}}),\n",
    "    },\n",
    "}\n",
    "subside(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b85caa84",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit subside(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f304d286",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import raw, rise\n",
    "\n",
    "value = FastTreeValue({\n",
    "    'a': raw({'a': 1, 'b': {'x': 2, 'y': 3}}),\n",
    "    'b': raw({'a': 10, 'b': {'x': 20, 'y': 30}}),\n",
    "    'c': {\n",
    "        'x': raw({'a': 100, 'b': {'x': 200, 'y': 300}}),\n",
    "        'y': raw({'a': 400, 'b': {'x': 500, 'y': 600}}),\n",
    "    },\n",
    "})\n",
    "rise(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d34321",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit rise(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ed4793",
   "metadata": {},
   "outputs": [],
   "source": [
    "vt = {'a': None, 'b': {'x': None, 'y': None}}\n",
    "rise(value, template=vt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6ad3639",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit rise(value, template=vt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19ee26c6",
   "metadata": {},
   "source": [
    "### pytree.tree_transpose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f24f0f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_structure, tree_transpose\n",
    "\n",
    "# Outer structure: the layout of the container that holds the inner trees.\n",
    "sto = tree_structure({'a': 1, 'b': 2, 'c': {'x': 3, 'y': 4}})\n",
    "# Inner structure: the layout found at every leaf position of the outer tree.\n",
    "sti = tree_structure({'a': 1, 'b': {'x': 2, 'y': 3}})\n",
    "\n",
    "# `value` follows `sto` on the outside and holds an `sti`-shaped tree at each\n",
    "# outer leaf, so the structures handed to tree_transpose describe the data.\n",
    "value = {\n",
    "    'a': {'a': 1, 'b': {'x': 2, 'y': 3}},\n",
    "    'b': {'a': 10, 'b': {'x': 20, 'y': 30}},\n",
    "    'c': {\n",
    "        'x': {'a': 100, 'b': {'x': 200, 'y': 300}},\n",
    "        'y': {'a': 400, 'b': {'x': 500, 'y': 600}},\n",
    "    },\n",
    "}\n",
    "tree_transpose(sto, sti, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d04072f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit tree_transpose(sto, sti, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a580a3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
